from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import dataset_processing
import torch.optim as optim
from torch.autograd import Variable
import torch
import torchvision.models as models
from efficientnet_pytorch import EfficientNet

import torchlib

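# Use torchvision's ResNeXt-101 (32x8d), replacing its final FC layer with dropout + a linear layer sized for this project's number of classes (27), and apply an element-wise Sigmoid to the outputs (torchvision's models themselves return raw logits, not Softmax probabilities).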
class resnext101(nn.Module):
    """ResNeXt-101 (32x8d) backbone with a custom sigmoid-activated head.

    The torchvision model's final ``fc`` layer is replaced by
    ``Dropout -> Linear(in_features, nlabel)``, and ``forward`` applies an
    element-wise sigmoid, so every output is an independent score in (0, 1).

    Args:
        nlabel: Number of output units (labels).
        pretrained: Load torchvision's pretrained ImageNet weights for the
            backbone. Defaults to ``True`` (the original behaviour); pass
            ``False`` to build the architecture without downloading weights.
        dropout: Dropout probability applied before the new linear layer.
            Defaults to ``0.2`` (the original hard-coded value).
    """

    def __init__(self, nlabel: int, *, pretrained: bool = True, dropout: float = 0.2):
        super().__init__()
        resnet = models.resnext101_32x8d(pretrained=pretrained)
        # Swap out the ImageNet classifier. Reusing the old layer's input
        # width lets the new head consume the backbone's pooled features.
        resnet.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features=resnet.fc.in_features, out_features=nlabel),
        )
        self.base_model = resnet
        self.sigm = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid scores of shape (batch, nlabel) for the input batch.

        ``x`` is assumed to be an image batch in the layout torchvision
        classifiers expect, (N, 3, H, W) — TODO confirm against the loader.
        """
        # NOTE(review): if training uses nn.BCELoss on these outputs, consider
        # returning logits and using nn.BCEWithLogitsLoss for numerical
        # stability (kept as-is here to preserve callers' expectations).
        return self.sigm(self.base_model(x))


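# Use torchvision's DenseNet-161, replacing its final classifier layer with dropout + a linear layer sized for this project's number of classes (27), and apply an element-wise Sigmoid to the outputs (torchvision's models themselves return raw logits, not Softmax probabilities).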
class densenet161(nn.Module):
    """DenseNet-161 backbone with a custom sigmoid-activated head.

    The torchvision model's ``classifier`` layer is replaced by
    ``Dropout -> Linear(in_features, nlabel)``, and ``forward`` applies an
    element-wise sigmoid, so every output is an independent score in (0, 1).

    Args:
        nlabel: Number of output units (labels).
        pretrained: Load torchvision's pretrained ImageNet weights for the
            backbone. Defaults to ``True`` (the original behaviour); pass
            ``False`` to build the architecture without downloading weights.
        dropout: Dropout probability applied before the new linear layer.
            Defaults to ``0.2`` (the original hard-coded value).
    """

    def __init__(self, nlabel: int, *, pretrained: bool = True, dropout: float = 0.2):
        super().__init__()
        densenet = models.densenet161(pretrained=pretrained)
        # Swap out the ImageNet classifier. Reusing the old layer's input
        # width lets the new head consume the backbone's pooled features.
        densenet.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features=densenet.classifier.in_features, out_features=nlabel),
        )
        self.base_model = densenet
        self.sigm = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid scores of shape (batch, nlabel) for the input batch.

        ``x`` is assumed to be an image batch in the layout torchvision
        classifiers expect, (N, 3, H, W) — TODO confirm against the loader.
        """
        # NOTE(review): if training uses nn.BCELoss on these outputs, consider
        # returning logits and using nn.BCEWithLogitsLoss for numerical
        # stability (kept as-is here to preserve callers' expectations).
        return self.sigm(self.base_model(x))
